bb4e68
@@ -17,7 +17,9 @@
 package org.springframework.web.socket.server.standard;
 
 import java.lang.reflect.Constructor;
+import java.lang.reflect.InvocationHandler;
 import java.lang.reflect.Method;
+import java.lang.reflect.Proxy;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
@@ -31,7 +33,6 @@
import javax.websocket.Endpoint;
 import javax.websocket.Extension;
 import javax.websocket.server.ServerEndpointConfig;
 
-import io.undertow.server.HttpServerExchange;
 import io.undertow.server.HttpUpgradeListener;
 import io.undertow.servlet.api.InstanceFactory;
 import io.undertow.servlet.api.InstanceHandle;
@@ -50,7 +51,6 @@
import io.undertow.websockets.jsr.handshake.JsrHybi07Handshake;
 import io.undertow.websockets.jsr.handshake.JsrHybi08Handshake;
 import io.undertow.websockets.jsr.handshake.JsrHybi13Handshake;
 import io.undertow.websockets.spi.WebSocketHttpExchange;
-import org.xnio.StreamConnection;
 
 import org.springframework.http.server.ServerHttpRequest;
 import org.springframework.http.server.ServerHttpResponse;
@@ -122,8 +122,7 @@
public class UndertowRequestUpgradeStrategy extends AbstractStandardUpgradeStrat
 
 			// Adapting between different Pool API types in Undertow 1.0-1.2 vs 1.3
 			getBufferPoolMethod = WebSocketHttpExchange.class.getMethod("getBufferPool");
-			createChannelMethod = Handshake.class.getMethod("createChannel",
-					WebSocketHttpExchange.class, StreamConnection.class, getBufferPoolMethod.getReturnType());
+			createChannelMethod = ReflectionUtils.findMethod(Handshake.class, "createChannel", (Class<?>[]) null);
 		}
 		catch (Throwable ex) {
 			throw new IllegalStateException("Incompatible Undertow API version", ex);
@@ -162,31 +161,39 @@
public class UndertowRequestUpgradeStrategy extends AbstractStandardUpgradeStrat
 
 		HttpServletRequest servletRequest = getHttpServletRequest(request);
 		HttpServletResponse servletResponse = getHttpServletResponse(response);
-
 		final ServletWebSocketHttpExchange exchange = createHttpExchange(servletRequest, servletResponse);
 		exchange.putAttachment(HandshakeUtil.PATH_PARAMS, Collections.<String, String>emptyMap());
 
 		ServerWebSocketContainer wsContainer = (ServerWebSocketContainer) getContainer(servletRequest);
 		final EndpointSessionHandler endpointSessionHandler = new EndpointSessionHandler(wsContainer);
-
 		final ConfiguredServerEndpoint configuredServerEndpoint = createConfiguredServerEndpoint(
 				selectedProtocol, selectedExtensions, endpoint, servletRequest);
-
 		final Handshake handshake = getHandshakeToUse(exchange, configuredServerEndpoint);
 
-		exchange.upgradeChannel(new HttpUpgradeListener() {
-			@Override
-			public void handleUpgrade(StreamConnection connection, HttpServerExchange serverExchange) {
-				Object bufferPool = ReflectionUtils.invokeMethod(getBufferPoolMethod, exchange);
-				WebSocketChannel channel = (WebSocketChannel) ReflectionUtils.invokeMethod(
-						createChannelMethod, handshake, exchange, connection, bufferPool);
-				if (peerConnections != null) {
-					peerConnections.add(channel);
-				}
-				endpointSessionHandler.onConnect(exchange, channel);
-			}
-		});
-
+		HttpUpgradeListener upgradeListener = (HttpUpgradeListener) Proxy.newProxyInstance(
+				getClass().getClassLoader(), new Class<?>[] {HttpUpgradeListener.class},
+				new InvocationHandler() {
+					@Override
+					public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
+						if ("handleUpgrade".equals(method.getName())) {  // the HttpUpgradeListener callback
+							Object connection = args[0];  // currently an XNIO StreamConnection
+							Object bufferPool = ReflectionUtils.invokeMethod(getBufferPoolMethod, exchange);
+							WebSocketChannel channel = (WebSocketChannel) ReflectionUtils.invokeMethod(
+									createChannelMethod, handshake, exchange, connection, bufferPool);
+							if (peerConnections != null) {
+								peerConnections.add(channel);
+							}
+							endpointSessionHandler.onConnect(exchange, channel);
+							return null;  // handleUpgrade is void
+						}
+						else if ("equals".equals(method.getName())) {
+							return (proxy == args[0]);  // identity: handler.equals(proxy) would break reflexivity
+						}
+						return ReflectionUtils.invokeMethod(method, this, args);  // hashCode, toString
+					}
+				});
+
+		exchange.upgradeChannel(upgradeListener);
 		handshake.handshake(exchange);
 	}
 
